// This file is licensed under the Elastic License 2.0. Copyright 2021-present, StarRocks Limited.

#include "exec/pipeline/hashjoin/hash_join_build_operator.h"

#include "runtime/runtime_filter_worker.h"

namespace starrocks {
namespace pipeline {

// Constructs the build-side operator of a pipeline hash join.
//
//   factory, id, name, plan_node_id  - forwarded unchanged to the Operator base class.
//   hash_joiner                      - joiner this operator appends build chunks to. It is taken by value
//                                      and moved into the member, so no extra copy of the handle is made.
//   driver_sequence                  - sequence number of the driver that owns this operator; used in
//                                      set_finishing() as the index passed to the partial filter merger.
//   partial_rf_merger                - non-owning pointer to the merger that combines the partial runtime
//                                      filters handed over in set_finishing().
//   distribution_mode                - join distribution mode; BROADCAST gets special handling in
//                                      set_finishing().
//   any_broadcast_builder_finished   - flag held by reference and flipped with a compare-exchange in
//                                      set_finishing() by the first BROADCAST builder to finish.
HashJoinBuildOperator::HashJoinBuildOperator(OperatorFactory* factory, int32_t id, const string& name,
                                             int32_t plan_node_id, HashJoinerPtr hash_joiner, size_t driver_sequence,
                                             PartialRuntimeFilterMerger* partial_rf_merger,
                                             const TJoinDistributionMode::type distribution_mode,
                                             std::atomic<bool>& any_broadcast_builder_finished)
        : Operator(factory, id, name, plan_node_id),
          _hash_joiner(std::move(hash_joiner)),
          _driver_sequence(driver_sequence),
          _partial_rf_merger(partial_rf_merger),
          _distribution_mode(distribution_mode),
          _any_broadcast_builder_finished(any_broadcast_builder_finished) {
    // Register this operator as a user of the joiner; close() issues the matching unref().
    // The member (not the moved-from parameter) must be used from here on.
    _hash_joiner->ref();
}

// Hands one build-side chunk to the hash joiner, which appends it to its hash table input.
// The joiner's status is returned to the caller unchanged.
Status HashJoinBuildOperator::push_chunk(RuntimeState* state, const vectorized::ChunkPtr& chunk) {
    auto append_status = _hash_joiner->append_chunk_to_ht(state, chunk);
    return append_status;
}

// Prepares the Operator base class first and, only when that succeeds, the hash joiner.
// Returns the first non-OK status encountered, otherwise the joiner's status.
Status HashJoinBuildOperator::prepare(RuntimeState* state) {
    // Stop right away if the base class could not be prepared.
    RETURN_IF_ERROR(Operator::prepare(state));

    auto joiner_status = _hash_joiner->prepare(state);
    return joiner_status;
}
// Drops the reference on the hash joiner that the constructor took via ref(), then closes the
// Operator base class.
//
// The base-class close always runs, even when unref() reports an error, so that base teardown is
// never skipped on the error path. When both steps fail, the unref() error is the one returned
// because it happened first.
Status HashJoinBuildOperator::close(RuntimeState* state) {
    const Status unref_status = _hash_joiner->unref(state);
    const Status base_status = Operator::close(state);
    if (!unref_status.ok()) {
        return unref_status;
    }
    return base_status;
}

// This operator only accepts chunks (see push_chunk); pulling from it is a programming error.
// The CHECK fires unconditionally; the NotSupported status after it gives the function a return
// value on every path (presumably only reachable if the CHECK does not terminate the process).
StatusOr<vectorized::ChunkPtr> HashJoinBuildOperator::pull_chunk(RuntimeState* state) {
    const char* const unsupported_msg = "pull_chunk not supported in HashJoinBuildOperator";
    CHECK(false) << unsupported_msg;
    return Status::NotSupported(unsupported_msg);
}

// Marks this operator as finished, builds the hash table, contributes this builder's partial
// runtime filters to the merger and finally moves the joiner into its probe phase.
//
// Under BROADCAST distribution only the first builder that flips the shared
// `_any_broadcast_builder_finished` flag goes on to create and hand over runtime filters; every
// other builder enters the probe phase immediately after building its hash table.
void HashJoinBuildOperator::set_finishing(RuntimeState* state) {
    _is_finished = true;
    // NOTE(review): whatever build_ht() returns is not inspected here — confirm it cannot fail.
    _hash_joiner->build_ht(state);

    const bool is_broadcast = (_distribution_mode == TJoinDistributionMode::BROADCAST);
    if (is_broadcast) {
        // Atomically claim the "first finished broadcast builder" role. Losing the race means
        // some other builder is responsible for the runtime filters.
        bool flag_unset = false;
        const bool is_first_finisher = _any_broadcast_builder_finished.compare_exchange_strong(flag_unset, true);
        if (!is_first_finisher) {
            _hash_joiner->enter_probe_phase();
            return;
        }
    }

    // The single BROADCAST contributor always uses index 0; otherwise each builder uses the
    // index given by its own driver sequence.
    size_t merger_slot = is_broadcast ? 0 : _driver_sequence;

    _hash_joiner->create_runtime_filters(state);

    auto build_row_count = _hash_joiner->get_ht_row_count();
    auto& local_in_filters = _hash_joiner->get_runtime_in_filters();
    auto& local_bloom_params = _hash_joiner->get_runtime_bloom_filter_build_params();
    auto& local_bloom_filters = _hash_joiner->get_runtime_bloom_filters();

    // Hand the filters produced by this builder over to the PartialRuntimeFilterMerger so they
    // can be merged into the total ones. The joiner's containers are moved from.
    auto merge_result =
            _partial_rf_merger->add_partial_filters(merger_slot, build_row_count, std::move(local_in_filters),
                                                    std::move(local_bloom_params), std::move(local_bloom_filters));

    // NOTE(review): a non-OK merge_result is silently dropped because this function returns void.
    // A true value is taken to mean the merged totals are ready to be fetched — TODO confirm
    // against PartialRuntimeFilterMerger.
    const bool totals_ready = merge_result.ok() && merge_result.value();
    if (totals_ready) {
        auto&& total_in_filters = _partial_rf_merger->get_total_in_filters();
        auto&& total_bloom_filters = _partial_rf_merger->get_total_bloom_filters();

        // Publish the runtime bloom filters before their ownership is given away below.
        state->runtime_filter_port()->publish_runtime_filters(total_bloom_filters);

        // Move both kinds of runtime filters into the RuntimeFilterHub, keyed by plan node id.
        auto&& hub = runtime_filter_hub();
        auto collector = std::make_unique<RuntimeFilterCollector>(std::move(total_in_filters),
                                                                  std::move(total_bloom_filters));
        hub->set_collector(_plan_node_id, std::move(collector));
    }

    _hash_joiner->enter_probe_phase();
}

// Constructs the factory that produces HashJoinBuildOperator instances for one plan node.
//
//   id, plan_node_id     - forwarded to the OperatorFactory base together with the fixed operator
//                          name "hash_join_build".
//   hash_joiner_factory  - factory used in create() to obtain the joiner for a driver sequence. It is
//                          taken by value and moved into the member, so no extra copy of the handle
//                          is made.
//   partial_rf_merger    - merger whose ownership is transferred to this factory; create() hands a
//                          raw pointer to it to every operator it builds.
//   distribution_mode    - join distribution mode passed on to every created operator.
HashJoinBuildOperatorFactory::HashJoinBuildOperatorFactory(
        int32_t id, int32_t plan_node_id, HashJoinerFactoryPtr hash_joiner_factory,
        std::unique_ptr<PartialRuntimeFilterMerger>&& partial_rf_merger,
        const TJoinDistributionMode::type distribution_mode)
        : OperatorFactory(id, "hash_join_build", plan_node_id),
          _hash_joiner_factory(std::move(hash_joiner_factory)),
          _partial_rf_merger(std::move(partial_rf_merger)),
          _distribution_mode(distribution_mode) {}

// Prepares the OperatorFactory base class and then, if that succeeded, the hash joiner factory.
// Returns the first non-OK status encountered, otherwise the joiner factory's status.
Status HashJoinBuildOperatorFactory::prepare(RuntimeState* state) {
    // Stop right away if the base class could not be prepared.
    RETURN_IF_ERROR(OperatorFactory::prepare(state));

    auto joiner_factory_status = _hash_joiner_factory->prepare(state);
    return joiner_factory_status;
}

// Closes the hash joiner factory first and the OperatorFactory base class afterwards,
// i.e. in the reverse of the order used by prepare().
void HashJoinBuildOperatorFactory::close(RuntimeState* state) {
    // Joiner-side teardown.
    _hash_joiner_factory->close(state);

    // Base-class teardown.
    this->OperatorFactory::close(state);
}

// Creates the HashJoinBuildOperator for the given driver sequence.
//
// The joiner is obtained from the hash joiner factory for that same driver sequence. The merger,
// distribution mode and broadcast-finished flag are members of this factory and are passed to
// every operator it creates. `degree_of_parallelism` is not used by this implementation.
OperatorPtr HashJoinBuildOperatorFactory::create(int32_t degree_of_parallelism, int32_t driver_sequence) {
    auto hash_joiner = _hash_joiner_factory->create(driver_sequence);
    PartialRuntimeFilterMerger* merger = _partial_rf_merger.get();

    return std::make_shared<HashJoinBuildOperator>(this, _id, _name, _plan_node_id, std::move(hash_joiner),
                                                   driver_sequence, merger, _distribution_mode,
                                                   _any_broadcast_builder_finished);
}
} // namespace pipeline
} // namespace starrocks
